{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "06cff9e4",
   "metadata": {},
   "source": [
    "# BGE Series"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "880e229d",
   "metadata": {},
   "source": [
    "In this Part, we will walk through the BGE series and introduce how to use the BGE embedding models."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2516fd49",
   "metadata": {},
   "source": [
    "## 1. BAAI General Embedding"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2113ee71",
   "metadata": {},
   "source": [
    "BGE stands for BAAI General Embedding, it's a series of embeddings models developed and published by Beijing Academy of Artificial Intelligence (BAAI)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "16515b99",
   "metadata": {},
   "source": [
    "A full support of APIs and related usages of BGE is maintained in [FlagEmbedding](https://github.com/FlagOpen/FlagEmbedding) on GitHub.\n",
    "\n",
    "Run the following cell to install FlagEmbedding in your environment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88095fd0",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture\n",
    "%pip install -U FlagEmbedding"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bc6e30a0",
   "metadata": {},
   "source": [
    "The collection of BGE models can be found in [Huggingface collection](https://huggingface.co/collections/BAAI/bge-66797a74476eb1f085c7446d)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "67a16ccf",
   "metadata": {},
   "source": [
    "## 2. BGE Series Models"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2e10034a",
   "metadata": {},
   "source": [
    "### 2.1 BGE"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0cdc6702",
   "metadata": {},
   "source": [
    "The very first version of BGE has 6 models, with 'large', 'base', and 'small' for English and Chinese. "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "04b75f72",
   "metadata": {},
   "source": [
    "| Model  | Language |   Parameters   |   Model Size   |    Description    |   Base Model     |\n",
    "|:-------|:--------:|:--------------:|:--------------:|:-----------------:|:----------------:|\n",
    "| [BAAI/bge-large-en](https://huggingface.co/BAAI/bge-large-en)   | English |    335M    |    1.34 GB   |              Embedding Model which map text into vector                            |  BERT  |\n",
    "| [BAAI/bge-base-en](https://huggingface.co/BAAI/bge-base-en)     | English |    109M    |    438 MB    |          a base-scale model but with similar ability to `bge-large-en`  |  BERT  |\n",
    "| [BAAI/bge-small-en](https://huggingface.co/BAAI/bge-small-en)   | English |    33.4M   |    133 MB    |          a small-scale model but with competitive performance                    |  BERT  |\n",
    "| [BAAI/bge-large-zh](https://huggingface.co/BAAI/bge-large-zh)   | Chinese |    326M    |    1.3 GB    |              Embedding Model which map text into vector                            |  BERT  |\n",
    "| [BAAI/bge-base-zh](https://huggingface.co/BAAI/bge-base-zh)     | Chinese |    102M    |    409 MB    |           a base-scale model but with similar ability to `bge-large-zh`           |  BERT  |\n",
    "| [BAAI/bge-small-zh](https://huggingface.co/BAAI/bge-small-zh)   | Chinese |    24M     |    95.8 MB   |           a small-scale model but with competitive performance                    |  BERT  |"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c9c45d17",
   "metadata": {},
   "source": [
    "For inference, import FlagModel from FlagEmbedding and initialize the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89e07751",
   "metadata": {},
   "outputs": [],
   "source": [
    "from FlagEmbedding import FlagModel\n",
    "\n",
    "# Load BGE model\n",
    "model = FlagModel('BAAI/bge-base-en',\n",
    "                  query_instruction_for_retrieval=\"Represent this sentence for searching relevant passages:\",\n",
    "                  use_fp16=True)\n",
    "\n",
    "queries = [\"query 1\", \"query 2\"]\n",
    "corpus = [\"passage 1\", \"passage 2\"]\n",
    "\n",
    "# encode the queries and corpus\n",
    "q_embeddings = model.encode(queries)\n",
    "p_embeddings = model.encode(corpus)\n",
    "\n",
    "# compute the similarity scores\n",
    "scores = q_embeddings @ p_embeddings.T\n",
    "print(scores)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6c8e69ed",
   "metadata": {},
   "source": [
    "To use `FlagModel`:\n",
    "```\n",
    "FlagModel.encode(sentences, batch_size=256, max_length=512, convert_to_numpy=True)\n",
    "```\n",
    "The *encode()* function directly encode the input sentences to embedding vectors.\n",
    "```\n",
    "FlagModel.encode_queries(sentences, batch_size=256, max_length=512, convert_to_numpy=True)\n",
    "```\n",
    "The *encode_queries()* function concatenate the `query_instruction_for_retrieval` with each of the input query, and then call `encode()`."
   ]
  },
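  {
   "cell_type": "markdown",
   "id": "a3f1c2d4",
   "metadata": {},
   "source": [
    "As a rough illustration of the prefixing step that `encode_queries()` performs before encoding, the sketch below builds the instruction-prefixed strings by hand. The helper `add_query_instruction` is hypothetical and not part of FlagEmbedding, and the plain concatenation without a separator is an assumption; check the source of your installed version for the exact format.\n",
    "\n",
    "```python\n",
    "# Hypothetical helper, not part of FlagEmbedding: mimics instruction prefixing.\n",
    "def add_query_instruction(queries, instruction):\n",
    "    # Assumption: the instruction is placed directly in front of each query.\n",
    "    return [f'{instruction}{query}' for query in queries]\n",
    "\n",
    "instruction = 'Represent this sentence for searching relevant passages:'\n",
    "queries = ['query 1', 'query 2']\n",
    "\n",
    "# These strings, rather than the raw queries, would then be embedded.\n",
    "prefixed = add_query_instruction(queries, instruction)\n",
    "for text in prefixed:\n",
    "    print(text)\n",
    "```\n",
    "\n",
    "Passages are encoded without the instruction, so only the query side changes."
   ]
  },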
  {
   "cell_type": "markdown",
   "id": "2c86a5a3",
   "metadata": {},
   "source": [
    "### 2.2 BGE v1.5"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "454ff7aa",
   "metadata": {},
   "source": [
    "BGE 1.5 alleviate the issue of the similarity distribution, and enhance retrieval ability without instruction."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "30b1f897",
   "metadata": {},
   "source": [
    "| Model  | Language |   Parameters   |   Model Size   |    Description    |   Base Model     |\n",
    "|:-------|:--------:|:--------------:|:--------------:|:-----------------:|:----------------:|\n",
    "| [BAAI/bge-large-en-v1.5](https://huggingface.co/BAAI/bge-large-en-v1.5)   | English |    335M    |    1.34 GB   |     version 1.5 with more reasonable similarity distribution      |   BERT   |\n",
    "| [BAAI/bge-base-en-v1.5](https://huggingface.co/BAAI/bge-base-en-v1.5)     | English |    109M    |    438 MB    |     version 1.5 with more reasonable similarity distribution      |   BERT   |\n",
    "| [BAAI/bge-small-en-v1.5](https://huggingface.co/BAAI/bge-small-en-v1.5)   | English |    33.4M   |    133 MB    |     version 1.5 with more reasonable similarity distribution      |   BERT   |\n",
    "| [BAAI/bge-large-zh-v1.5](https://huggingface.co/BAAI/bge-large-zh-v1.5)   | Chinese |    326M    |    1.3 GB    |     version 1.5 with more reasonable similarity distribution      |   BERT   |\n",
    "| [BAAI/bge-base-zh-v1.5](https://huggingface.co/BAAI/bge-base-zh-v1.5)     | Chinese |    102M    |    409 MB    |     version 1.5 with more reasonable similarity distribution      |   BERT   |\n",
    "| [BAAI/bge-small-zh-v1.5](https://huggingface.co/BAAI/bge-small-zh-v1.5)   | Chinese |    24M     |    95.8 MB   |     version 1.5 with more reasonable similarity distribution      |   BERT   |"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ed00c504",
   "metadata": {},
   "source": [
    "BGE 1.5 models shares the same API of `FlagModel` with BGE models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "9b17afcc",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[0.736794  0.5989914]\n",
      " [0.5684842 0.7461165]]\n"
     ]
    }
   ],
   "source": [
    "model = FlagModel('BAAI/bge-base-en-v1.5',\n",
    "                  query_instruction_for_retrieval=\"Represent this sentence for searching relevant passages:\",\n",
    "                  use_fp16=True)\n",
    "\n",
    "queries = [\"query 1\", \"query 2\"]\n",
    "corpus = [\"passage 1\", \"passage 2\"]\n",
    "\n",
    "# encode the queries and corpus\n",
    "q_embeddings = model.encode(queries)\n",
    "p_embeddings = model.encode(corpus)\n",
    "\n",
    "# compute the similarity scores\n",
    "scores = q_embeddings @ p_embeddings.T\n",
    "print(scores)"
   ]
  },
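  {
   "cell_type": "markdown",
   "id": "b7e9d015",
   "metadata": {},
   "source": [
    "The `@` product above yields one score per query-passage pair. It can be read as cosine similarity when the embeddings are L2-normalized, which is assumed here to be the default behaviour of `FlagModel` (see the `normalize_embeddings` option in your installed version). The sketch below uses plain Python and made-up 2-dimensional vectors, with the hypothetical helpers `l2_normalize` and `dot`, to show why the two coincide.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "# Hypothetical helpers for illustration only.\n",
    "def l2_normalize(vec):\n",
    "    # Scale the vector so that its Euclidean length is 1.\n",
    "    norm = math.sqrt(sum(x * x for x in vec))\n",
    "    return [x / norm for x in vec]\n",
    "\n",
    "def dot(a, b):\n",
    "    return sum(x * y for x, y in zip(a, b))\n",
    "\n",
    "# Made-up vectors standing in for a query and a passage embedding.\n",
    "raw_q = [3.0, 4.0]\n",
    "raw_p = [4.0, 3.0]\n",
    "\n",
    "# Cosine similarity computed from the raw vectors.\n",
    "cosine = dot(raw_q, raw_p) / (math.sqrt(dot(raw_q, raw_q)) * math.sqrt(dot(raw_p, raw_p)))\n",
    "\n",
    "# Inner product of the normalized vectors gives the same value.\n",
    "score = dot(l2_normalize(raw_q), l2_normalize(raw_p))\n",
    "print(round(cosine, 4), round(score, 4))\n",
    "```"
   ]
  },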
  {
   "cell_type": "markdown",
   "id": "38c3ce1c",
   "metadata": {},
   "source": [
    "### 2.3 LLM-Embedder"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1bc3fee0",
   "metadata": {},
   "source": [
    "LLM-Embedder is a unified embedding model supporting diverse retrieval augmentation needs for LLMs. It is fine-tuned over 6 tasks:\n",
    "- Question Answering (qa)\n",
    "- Conversational Search (convsearch)\n",
    "- Long Conversation (chat)\n",
    "- Long-Rnage Language Modeling (lrlm)\n",
    "- In-Context Learning (icl)\n",
    "- Tool Learning (tool)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "13b926e9",
   "metadata": {},
   "source": [
    "| Model  | Language |   Parameters   |   Model Size   |    Description    |   Base Model     |\n",
    "|:-------|:--------:|:--------------:|:--------------:|:-----------------:|:----------------:|\n",
    "| [BAAI/llm-embedder](https://huggingface.co/BAAI/llm-embedder)             |   English | 109M |  438 MB  |      a unified embedding model to support diverse retrieval augmentation needs for LLMs       | BERT |"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a7b3f109",
   "metadata": {},
   "source": [
    "To use `LLMEmbedder`:\n",
    "```python\n",
    "LLMEmbedder.encode_queries(\n",
    "    queries, \n",
    "    batch_size=256, \n",
    "    max_length=256, \n",
    "    task='qa'\n",
    ")\n",
    "```\n",
    "The *encode_queries()* will call the *_encode()* functions (similar to the *encode()* in `FlagModel`) and add the corresponding query instruction of the given *task* in front of each of the input *queries*.\n",
    "```python\n",
    "LLMEmbedder.encode_keys(\n",
    "    keys, \n",
    "    batch_size=256, \n",
    "    max_length=512, \n",
    "    task='qa'\n",
    ")\n",
    "```\n",
    "Similarly, *encode_keys()* also calls *_encode()* and automatically add instructions according to given task."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "5f077420",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[0.89705944 0.85341793]\n",
      " [0.8462474  0.90914035]]\n"
     ]
    }
   ],
   "source": [
    "from FlagEmbedding import LLMEmbedder\n",
    "\n",
    "# load the LLMEmbedder model\n",
    "model = LLMEmbedder('BAAI/llm-embedder', use_fp16=False)\n",
    "\n",
    "# Define queries and keys\n",
    "queries = [\"test query 1\", \"test query 2\"]\n",
    "keys = [\"test key 1\", \"test key 2\"]\n",
    "\n",
    "# Encode for a specific task (qa, icl, chat, lrlm, tool, convsearch)\n",
    "task = \"qa\"\n",
    "query_embeddings = model.encode_queries(queries, task=task)\n",
    "key_embeddings = model.encode_keys(keys, task=task)\n",
    "\n",
    "# compute the similarity scores\n",
    "similarity = query_embeddings @ key_embeddings.T\n",
    "print(similarity)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dcf2a82b",
   "metadata": {},
   "source": [
    "### 2.4 BGE M3"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cc5b5a5e",
   "metadata": {},
   "source": [
    "BGE-M3 is the new version of BGE models that is distinguished for its versatility in:\n",
    "- Multi-Functionality: Simultaneously perform the three common retrieval functionalities of embedding model: dense retrieval, multi-vector retrieval, and sparse retrieval.\n",
    "- Multi-Linguality: Supports more than 100 working languages.\n",
    "- Multi-Granularity: Can proces inputs with different granularityies, spanning from short sentences to long documents of up to 8192 tokens.\n",
    "\n",
    "For more details, feel free to check out the [paper](https://arxiv.org/pdf/2402.03216)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "41348e03",
   "metadata": {},
   "source": [
    "| Model  | Language |   Parameters   |   Model Size   |    Description    |   Base Model     |\n",
    "|:-------|:--------:|:--------------:|:--------------:|:-----------------:|:----------------:|\n",
    "| [BAAI/bge-m3](https://huggingface.co/BAAI/bge-m3)                   |    Multilingual     |   568M   |  2.27 GB  |  Multi-Functionality(dense retrieval, sparse retrieval, multi-vector(colbert)), Multi-Linguality, and Multi-Granularity(8192 tokens) | XLM-RoBERTa |"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "d4647625",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Fetching 30 files: 100%|██████████| 30/30 [00:00<00:00, 228780.22it/s]\n"
     ]
    }
   ],
   "source": [
    "from FlagEmbedding import BGEM3FlagModel\n",
    "\n",
    "model = BGEM3FlagModel('BAAI/bge-m3', use_fp16=True)\n",
    "\n",
    "sentences = [\"What is BGE M3?\", \"Defination of BM25\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1f89f1a9",
   "metadata": {},
   "source": [
    "```python\n",
    "BGEM3FlagModel.encode(\n",
    "    sentences, \n",
    "    batch_size=12, \n",
    "    max_length=8192, \n",
    "    return_dense=True, \n",
    "    return_sparse=False, \n",
    "    return_colbert_vecs=False\n",
    ")\n",
    "```\n",
    "It returns a dictionary like:\n",
    "```python\n",
    "{\n",
    "    'dense_vecs': 'array of dense embeddings of inputs if return_dense=True, otherwise None,'\n",
    "    'lexical_weights': 'array of dictionaries with keys and values are ids of tokens and their corresponding weights if return_sparse=True, otherwise None,'\n",
    "    'colbert_vecs': 'array of multi-vector embeddings of inputs if return_cobert_vecs=True, otherwise None,'\n",
    "}\n",
    "```"
   ]
  },
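  {
   "cell_type": "markdown",
   "id": "c4d8a6f2",
   "metadata": {},
   "source": [
    "To make the `lexical_weights` entry more concrete, the sketch below scores two sparse representations by summing the products of the weights of the token ids they share. It is a hand-written illustration with invented ids and weights: `lexical_matching_score` is a hypothetical helper, and the scoring rule is an assumption about how such weights are typically compared, not a copy of the library code.\n",
    "\n",
    "```python\n",
    "# Hypothetical helper for illustration only.\n",
    "def lexical_matching_score(weights_a, weights_b):\n",
    "    # Only token ids present in both inputs contribute to the score.\n",
    "    return sum(w * weights_b[t] for t, w in weights_a.items() if t in weights_b)\n",
    "\n",
    "# Invented token ids and weights, shaped like one entry of 'lexical_weights'.\n",
    "query_weights = {'101': 0.5, '202': 0.25}\n",
    "passage_weights = {'202': 0.5, '303': 0.75}\n",
    "\n",
    "# Only token '202' is shared, so the score is 0.25 * 0.5.\n",
    "score = lexical_matching_score(query_weights, passage_weights)\n",
    "print(score)\n",
    "```\n",
    "\n",
    "Inputs with no token in common score 0, which is what makes this output behave like a learned term-matching signal."
   ]
  },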
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "f0b11cf0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# If you don't need such a long length of 8192 input tokens, you can set max_length to a smaller value to speed up encoding.\n",
    "embeddings = model.encode(\n",
    "    sentences, \n",
    "    max_length=10,\n",
    "    return_dense=True, \n",
    "    return_sparse=True, \n",
    "    return_colbert_vecs=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "72cba126",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dense embedding:\n",
      "[[-0.03411707 -0.04707828 -0.00089447 ...  0.04828531  0.00755427\n",
      "  -0.02961654]\n",
      " [-0.01041734 -0.04479263 -0.02429199 ... -0.00819298  0.01503995\n",
      "   0.01113793]]\n",
      "sparse embedding:\n",
      "[defaultdict(<class 'int'>, {'4865': 0.08362077, '83': 0.081469566, '335': 0.12964639, '11679': 0.25186998, '276': 0.17001738, '363': 0.26957875, '32': 0.040755156}), defaultdict(<class 'int'>, {'262': 0.050144322, '5983': 0.13689369, '2320': 0.045134712, '111': 0.06342201, '90017': 0.25167602, '2588': 0.33353207})]\n",
      "multi-vector:\n",
      "[array([[-8.6726490e-03, -4.8921868e-02, -3.0449261e-03, ...,\n",
      "        -2.2082448e-02,  5.7268854e-02,  1.2811369e-02],\n",
      "       [-8.8765034e-03, -4.6860173e-02, -9.5845405e-03, ...,\n",
      "        -3.1404708e-02,  5.3911421e-02,  6.8714428e-03],\n",
      "       [ 1.8445771e-02, -4.2359587e-02,  8.6754939e-04, ...,\n",
      "        -1.9803897e-02,  3.8384371e-02,  7.6852231e-03],\n",
      "       ...,\n",
      "       [-2.5543230e-02, -1.6561864e-02, -4.2125367e-02, ...,\n",
      "        -4.5030322e-02,  4.4091221e-02, -1.0043185e-02],\n",
      "       [ 4.9905590e-05, -5.5475257e-02,  8.4884483e-03, ...,\n",
      "        -2.2911752e-02,  6.0379632e-02,  9.3577225e-03],\n",
      "       [ 2.5895271e-03, -2.9331330e-02, -1.8961012e-02, ...,\n",
      "        -8.0389353e-03,  3.2842189e-02,  4.3894034e-02]], dtype=float32), array([[ 0.01715658,  0.03835309, -0.02311821, ...,  0.00146474,\n",
      "         0.02993429, -0.05985384],\n",
      "       [ 0.00996143,  0.039217  , -0.03855301, ...,  0.00599566,\n",
      "         0.02722942, -0.06509776],\n",
      "       [ 0.01777726,  0.03919311, -0.01709837, ...,  0.00805702,\n",
      "         0.03988946, -0.05069073],\n",
      "       ...,\n",
      "       [ 0.05474931,  0.0075684 ,  0.00329455, ..., -0.01651684,\n",
      "         0.02397249,  0.00368039],\n",
      "       [ 0.0093503 ,  0.05022853, -0.02385841, ...,  0.02575599,\n",
      "         0.00786822, -0.03260205],\n",
      "       [ 0.01805054,  0.01337725,  0.00016697, ...,  0.01843987,\n",
      "         0.01374448,  0.00310114]], dtype=float32)]\n"
     ]
    }
   ],
   "source": [
    "print(f\"dense embedding:\\n{embeddings['dense_vecs']}\")\n",
    "print(f\"sparse embedding:\\n{embeddings['lexical_weights']}\")\n",
    "print(f\"multi-vector:\\n{embeddings['colbert_vecs']}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
